import warnings

import torch
import numpy as np
import matplotlib.pyplot as plt

from io import BytesIO
from IPython.display import display, Image

class Plotter:
    """
    Class for plotting progress during training of the GFN.

    ``__init__`` prepares ``plt.subplot_mosaic`` layouts (dicts with ``mosaic`` and
    ``figsize`` keys). The panel names in those layouts are the axis keys the
    ``plot_*`` methods draw on; ``plot_sampler`` renders ``self.sampler_mosaic``.
    """

    def __init__(self, config, env, gfn):
        """
        Store the experiment objects and build the mosaic layouts for ``env.dim``.

        Args:
            config: experiment configuration (``config["metad"]["active"]`` is read
                by ``plot_sampler``).
            env: environment whose grids, bounds, rewards and ``dim`` are plotted.
            gfn: GFlowNet whose samples, models and noise schedule are plotted.
        """
        self.config = config
        self.env = env
        self.gfn = gfn

        self.device = self.env.device

        # Snapshots (iteration, confining potential, bias potential) accumulated by
        # ``plot_metadynamics`` when 1D figures are saved to disk.
        self.global_data = []

        # Mosaic plots for visualisation
        if self.env.dim == 1:
            self.replay_buffer_mosaic = {"mosaic": [["rb_hist"], ["rb_traj"]], "figsize": (10, 5)}
            self.onpolicy_mosaic = {"mosaic": [["on_policy_hist"], ["on_policy_traj"]], "figsize": (10, 5)}
            self.metadynamics_mosaic = {"mosaic": [["metad"]], "figsize": (5, 5)}
            self.sampler_mosaic = {
                "mosaic": [["loss", "op_hist", "rb_hist", "metad", "forward_flow", "flow"],
                           ["Z", "op_traj", "rb_traj", "metad", "forward_flow", "flow"],
                           ["noise_schedule", "op_error", "empty", "empty", "empty", "empty"]],
                "figsize": (17, 8),
            }
        elif self.env.dim in (2, 3, 4):
            dims = range(self.env.dim)
            rb_plot_names = [f"rb_{i}" for i in dims]
            rb_traj_plot_names = [f"rb_traj_{i}" for i in dims]
            op_plot_names = [f"op_{i}" for i in dims]
            op_traj_plot_names = [f"op_traj_{i}" for i in dims]
            self.replay_buffer_mosaic = {"mosaic": [["rb_hist", "none"], rb_plot_names, rb_traj_plot_names], "figsize": (15, 10)}
            self.onpolicy_mosaic = {"mosaic": [["op_hist", "none"], op_plot_names, op_traj_plot_names], "figsize": (15, 10)}
            self.metadynamics_mosaic = {"mosaic": [["samples", "kde_potential", "bias"]], "figsize": (15, 5)}

            if self.env.dim == 2:
                self.sampler_mosaic = {
                    "mosaic": [["samples_on_potential", "kde_potential", "kde_potential_prime1", "kde_potential_prime2", "bias_potential", "noise_schedule"],
                               ["rb_hist", "rb_1", "rb_2", "op_1", "op_2", "loss"],
                               ["op_hist", "rb_1_traj", "rb_2_traj", "op_1_traj", "op_2_traj", "op_error"]],
                    "figsize": (15, 10),
                }
            else:
                # Dims 3 and 4: one column per coordinate (1-based names) plus a summary column.
                axes = range(1, self.env.dim + 1)
                self.sampler_mosaic = {
                    "mosaic": [[f"op_{k}" for k in axes] + ["op_error"],
                               [f"op_{k}_traj" for k in axes] + ["loss"],
                               [f"rb_{k}" for k in axes] + ["empty"],
                               [f"rb_{k}_traj" for k in axes] + ["empty"]],
                    "figsize": (15, 10),
                }
        else:
            warnings.warn("Plotting is not implemented for dimensions higher than 4.")

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array; torch tensors are detached and moved to the CPU first."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def plot_reward(self, ax: "plt.Axes"):
        """
        Plots the reward function of the environment.

        1D: reward curve over the first marginal grid. 2D: filled contour over ``env.grid``.

        Raises:
            NotImplementedError: for environments with more than 2 dimensions.
        """
        if self.env.dim == 1:
            # Same x-grid as the other 1D reward plots (reward_grid is sampled on marginal_grid[0]).
            ax.plot(self._to_numpy(self.env.marginal_grid[0]), self._to_numpy(self.env.reward_grid), color="black")
        elif self.env.dim == 2:
            self.plot_potential(ax)
        else:
            raise NotImplementedError("Plotting rewards for dimensions higher than 2 is not implemented yet.")

    def _layer_states(self, layer):
        """Build the ``(num_grid_points[0], 2)`` float batch of states ``(x, layer)`` for every x of the first marginal grid."""
        xs = self.env.marginal_grid[0]
        layer_column = torch.full((self.env.num_grid_points[0],), float(layer), device=self.device)
        return torch.vstack((xs, layer_column)).T.to(torch.float)

    def compute_forward_flow(self, traj_length, forward_model, head_idx):
        """
        Evaluate the forward policy mean on every (grid point, layer) pair.

        Args:
            traj_length: number of layers to evaluate.
            forward_model: callable ``forward_model(state, head_idx)``; the first element
                of its output is taken as the policy mean.
            head_idx: passed through to ``forward_model``.

        Returns:
            Array of shape ``(traj_length, num_grid_points[0])``.
        """
        num_points = self.env.num_grid_points[0]
        forward_flow = np.zeros((traj_length, num_points))
        for layer in range(traj_length):
            states = self._layer_states(layer)
            for j in range(num_points):
                pf_params = forward_model(states[j, :], head_idx)
                # float() works for CPU and GPU scalar tensors alike.
                forward_flow[layer, j] = float(pf_params[0])
        return forward_flow

    def compute_flow(self, traj_length, flow_model):
        """
        Evaluate ``flow_model(state)`` on every (grid point, layer) pair.

        Returns:
            Array of shape ``(traj_length, num_grid_points[0])``.
        """
        num_points = self.env.num_grid_points[0]
        flow = np.zeros((traj_length, num_points))
        for layer in range(traj_length):
            states = self._layer_states(layer)
            for j in range(num_points):
                flow[layer, j] = float(flow_model(states[j, :]))
        return flow

    def plot_forward_flow(self, ax, forward_flow, skip=50):
        """
        Draw the forward-policy displacement of each layer as line segments.

        Args:
            ax: axis to draw on.
            forward_flow: array ``(traj_length, num_grid_points[0])`` as returned by
                ``compute_forward_flow``; column ``j`` belongs to ``marginal_grid[0][j]``.
            skip: only every ``skip``-th grid point is drawn, to keep the plot readable.
        """
        forward_flow = np.asarray(forward_flow)
        grid = self._to_numpy(self.env.marginal_grid[0])
        traj_len = forward_flow.shape[0]
        # Full-grid indices of the drawn points, so positions and flow values stay aligned.
        sampled_idx = np.arange(0, len(grid), skip)
        if traj_len > 0 and len(grid) > 0:
            # Layer 0 is drawn as a single segment starting at x = 0, using the flow at the
            # grid midpoint (presumably the shared initial state — confirm).
            ax.plot([0, forward_flow[0, len(grid) // 2]], [0, 1], color='b')
        for layer in range(1, traj_len):
            for idx in sampled_idx:
                x = grid[idx]
                ax.plot([x, x + forward_flow[layer, idx]], [layer, layer + 1], color='b', alpha=0.3)
        ax.set_yticks([])
        ax.set_yticklabels([])

    def plot_flow(self, ax, flow, layer_wise_integrals, Z):
        """
        Plots the flow model and layer-wise integrals of the flow.

        Each layer's log-flow curve is centred on its mean and stacked 20 units above the
        previous one, next to a horizontal guide line annotated with the mean log-flow and
        the layer integral.

        Args:
            ax: axis to draw on.
            flow: array ``(traj_length, num_grid_points[0])`` of log-flows (see ``compute_flow``).
            layer_wise_integrals: one value per layer, printed next to that layer.
            Z: partition-function estimate shown in the title.
        """
        layer_gap = 20
        env_x = self._to_numpy(self.env.marginal_grid[0])
        for layer, log_flow in enumerate(np.asarray(flow)):
            mean_log_flow = np.mean(log_flow)
            baseline = layer_gap * layer
            label_x = env_x[-1] + 2
            ax.plot(env_x, log_flow + baseline - mean_log_flow, color='b', alpha=0.3, linewidth=0.5)
            # Horizontal guide line at this layer's baseline (y = 20 * layer).
            ax.plot([env_x[0], env_x[-1]], [baseline, baseline], color='black', alpha=0.3, linewidth=0.5)
            ax.text(label_x, baseline, f"logF = {round(mean_log_flow, 3)}", va='center', ha='left', color='black', alpha=0.3)
            ax.text(label_x, baseline - 2, f"Layer integral = {round(layer_wise_integrals[layer], 3)}", va='center', ha='left', color='black', alpha=0.3)
        ax.set_title('Z = {}'.format(round(Z, 3)))
        ax.set_yticks([])
        ax.set_yticklabels([])

    def _plot_terminal_states_1d(self, hist_ax, traj_ax, sample, large_batch, color):
        """
        Histogram of the terminal states of ``sample(large_batch)`` against the normalised
        reward, plus 500 trajectories from ``sample(500)`` (1D environments).
        """
        trajs = sample(large_batch)
        hist_ax.hist(self._to_numpy(trajs.states[:, -1, 0]), bins=100, density=True, alpha=0.5, color=color)
        hist_ax.plot(self._to_numpy(self.env.marginal_grid[0]), self._to_numpy(self.env.reward_grid / self.env.reward_Z), color="black")
        hist_ax.set_ylim(0, 1)
        self.plot_trajectories(traj_ax, sample(500).states, color=color)
        lower, upper = self.env.lower_bound[0], self.env.upper_bound[0]
        hist_ax.set_xlim(lower, upper)
        traj_ax.set_xlim(lower, upper)
        hist_ax.spines[['right', 'top', 'bottom']].set_visible(False)
        traj_ax.spines[['right', 'top']].set_visible(False)

    def plot_replay_buffer(self, axs, rb, save_dir=None):
        """
        Plots the replay buffer, reward distribution and a sample of trajectories from the replay buffer.

        Draws on ``axs["rb_hist"]`` and ``axs["rb_traj"]``; if ``save_dir`` is given the current
        pyplot figure is saved to ``<save_dir>/replay_buffer.png`` and cleared.
        """
        self._plot_terminal_states_1d(
            axs["rb_hist"], axs["rb_traj"],
            sample=lambda n: rb.sample(batch_size=n),
            large_batch=max(5000, int(rb.buffer.capacity / 2)),
            color="green",
        )
        if save_dir is not None:
            plt.savefig(save_dir + "/replay_buffer.png")
            plt.clf()

    def plot_onpolicy_trajs(self, axs, gfn, save_dir=None):
        """
        Plots the on-policy distribution, reward distribution and a sample of on-policy trajectories.

        Draws on ``axs["op_hist"]`` and ``axs["op_traj"]``; if ``save_dir`` is given the current
        pyplot figure is saved to ``<save_dir>/on_policy.png`` and cleared.
        """
        self._plot_terminal_states_1d(
            axs["op_hist"], axs["op_traj"],
            sample=gfn.sample_on_policy,
            large_batch=max(5000, int(gfn.config["replay_buffer"]["capacity"] / 2)),
            color="red",
        )
        if save_dir is not None:
            plt.savefig(save_dir + "/on_policy.png")
            plt.clf()

    @staticmethod
    def plot_trajectories(ax, trajectory, linewidth=0.05, alpha=0.6, color='black'):
        """
        Plots a sample of trajectories in 1D.

        ``trajectory`` is indexed as ``(n, length, features)``; the first feature of each
        trajectory is drawn against step numbers ``1..length``.
        """
        _, trajectory_length, _ = trajectory.shape
        steps = np.arange(1, trajectory_length + 1)
        for path in Plotter._to_numpy(trajectory[:, :, 0]):
            ax.plot(path, steps, alpha=alpha, linewidth=linewidth, color=color)
        ax.set_ylabel('Step')

    @staticmethod
    def plot_evolution(ax, iterations, quantity, quantity_name, save_dir=None, log=True):
        """
        Plots the evolution of a quantity over iterations (log-log axes when ``log`` is True).

        If ``save_dir`` is given, the current pyplot figure is saved to
        ``<save_dir>/<quantity_name>.png`` and then cleared.
        """
        ax.plot(iterations, quantity, label=quantity_name, color="black", linewidth=0.7)
        ax.set_xlabel("iteration")
        ax.set_ylabel(quantity_name)
        if log:
            ax.set_yscale("log")
            ax.set_xscale("log")
        if save_dir is not None:
            plt.savefig(save_dir + "/{}.png".format(quantity_name))
            plt.clf()

    @staticmethod
    def plot_noise_schedule(ax, gfn):
        """Plots ``gfn.off_policy_noise_schedule``."""
        ax.plot(gfn.off_policy_noise_schedule, color="black")

    def plot_potential(self, ax, mds=None):
        """
        Filled contour of the reward over the 2D ``env.grid``.

        Raises:
            NotImplementedError: if a metadynamics sampler ``mds`` is passed.
        """
        if mds is not None:
            raise NotImplementedError("Plotting the potential for metadynamics is not implemented yet.")
        grid = self.env.grid
        ax.contourf(self._to_numpy(grid[..., 0]), self._to_numpy(grid[..., 1]), self._to_numpy(self.env.reward_grid), cmap="RdBu_r", levels=40)

    def plot_trajectories_grid(self, axs: "dict[str, plt.Axes]", plot_type, sampler, title, color_map: str, color: str, save_dir=None):
        """
        Plots the trajectories of the replay buffer (if plot_type = "rb") or the on-policy samples (otherwise) on the grid.

        Samples come from ``sampler.sample(batch_size=...)`` for ``"rb"`` and from
        ``sampler.sample_on_policy(batch_size=...)`` otherwise: 5000 for the terminal-state
        histograms, 500 for the trajectory panels. Axes used are ``f"{plot_type}_hist"``
        (2D only, joint histogram), ``f"{plot_type}_{k}"`` and ``f"{plot_type}_{k}_traj"``
        for ``k = 1..env.dim``. If ``save_dir`` is given the current figure is saved to
        ``<save_dir>/<title>.png`` and cleared.

        Raises:
            NotImplementedError: for 1D environments.
        """
        if self.env.dim < 2:
            raise NotImplementedError("Grid trajectory plots require an environment with at least 2 dimensions.")

        def draw(batch_size):
            if plot_type == "rb":
                return sampler.sample(batch_size=batch_size)
            return sampler.sample_on_policy(batch_size=batch_size)

        final_states = self._to_numpy(draw(5000).states[:, -1, :])
        if self.env.dim == 2:
            hist_ax = axs[f"{plot_type}_hist"]
            hist_ax.hist2d(final_states[:, 0], final_states[:, 1], bins=self.env.grid_bins, density=True, cmap=color_map)
            # Limits span the environment box [lower_bound, upper_bound] on each axis.
            hist_ax.set_xlim(self.env.lower_bound[0], self.env.upper_bound[0])
            hist_ax.set_ylim(self.env.lower_bound[1], self.env.upper_bound[1])
        for d in range(self.env.dim):
            axs[f"{plot_type}_{d + 1}"].hist(final_states[:, d], bins=self.env.grid_bins[d], density=True, color=color, alpha=0.5, linewidth=0.05)

        states = self._to_numpy(draw(500).states)
        # One step value per stored state along the trajectory axis.
        steps = np.arange(1, states.shape[1] + 1)
        for d in range(self.env.dim):
            traj_ax = axs[f"{plot_type}_{d + 1}_traj"]
            for path in states[:, :, d]:
                traj_ax.plot(path, steps, alpha=0.1, linewidth=0.05, color=color)
            traj_ax.set_ylabel('Step')

        if save_dir is not None:
            plt.savefig(save_dir + f"/{title}.png")
            plt.clf()

    def plot_metadynamics(self, axs: "dict[str, plt.Axes]", mds, save_dir=None):
        """Plots a snapshot of the metadynamics algorithm using the current state of the Metadynamics sampler object (mds)."""
        if self.env.dim == 1:
            self._plot_metadynamics_1d(axs, mds, save_dir)
        else:
            self._plot_metadynamics_2d(axs, mds, save_dir)

    def _plot_metadynamics_1d(self, axs, mds, save_dir):
        """Draw walkers, confining and bias potentials and the reward on ``axs["metad"]``; optionally save the potential history as a PDF."""
        x = self._to_numpy(self.env.marginal_grid[0])
        reward = self._to_numpy(self.env.reward_grid)
        ax = axs["metad"]
        for metad_point in mds.z:
            ax.plot(metad_point, 0, "ro", markersize=3)
        ax.plot(x, mds.confining_potential, label=r"$V_t$", color="black")
        ax.plot(x, mds.bias_potential, label=r"$V^{bias}_t$", color="red")
        ax.plot(x, reward, color="black", alpha=0.5, ls='--')
        ax.set_title("Metadynamics")
        if save_dir is None:
            return

        self.global_data.append((mds.iteration_number, np.copy(mds.confining_potential), np.copy(mds.bias_potential)))
        # Draw the history on its own figure: clearing the current figure would wipe the axes drawn above.
        history_fig = plt.figure()
        try:
            history_ax = history_fig.subplots()
            n_snapshots = len(self.global_data)
            for age, (it, confining_potential, bias_potential) in enumerate(self.global_data):
                # Opacity grows with the snapshot index, so the newest potentials are the most visible.
                alpha = age / n_snapshots
                history_ax.plot(x, confining_potential, label=str(it), color=(0, 0, 0, alpha))
                history_ax.plot(x, bias_potential, label=str(it), color=(1, 0, 0, alpha))
            history_ax.plot(x, reward, color="black", alpha=0.5, ls='--')
            history_fig.savefig(save_dir + "/potential_and_bias_" + str(mds.iteration_number) + ".pdf")
        finally:
            plt.close(history_fig)

    def _plot_metadynamics_2d(self, axs, mds, save_dir):
        """Draw walkers on the reward contour, the KDE confining potential with forces, its two partial derivatives and the bias potential."""
        # Contour plot of the reward with the metadynamics walkers and their momenta.
        samples_ax = axs["samples_on_potential"]
        self.plot_potential(samples_ax)
        for point, momentum in zip(mds.z, mds.p):
            x, y = point[0].item(), point[1].item()
            samples_ax.scatter(x, y, c="r", s=7, marker="o")
            samples_ax.quiver(x, y, momentum[0].item(), momentum[1].item(), color="red", scale=2, units="xy")

        # Contour plot of the current kde confining potential
        axs["kde_potential"].contour(self._to_numpy(mds.confining_potential), cmap="RdBu_r", levels=40)

        grad = getattr(mds, "grad_confining_potential", None)
        if grad is not None:
            grad = self._to_numpy(grad)
            self._plot_confining_forces(axs["kde_potential"], mds.z, grad)
            # Derivatives of the confining potential along the first and second coordinate.
            self._plot_symmetric_contour(axs["kde_potential_prime1"], grad[0, :, :])
            self._plot_symmetric_contour(axs["kde_potential_prime2"], grad[1, :, :])

        # Plot of the bias potential
        axs["bias_potential"].contour(self._to_numpy(mds.bias_potential), cmap="RdBu_r", levels=40)
        if save_dir is not None:
            plt.savefig(save_dir + "/metad.png")
            plt.clf()

    def _plot_confining_forces(self, ax, points, grad):
        """
        Mark each walker on the KDE-potential contour and draw the confining force ``-grad`` at its grid cell.

        ``contour`` was called without x/y coordinates, so the axis is in array-index units;
        positions are mapped there with ``(value - lower) / (upper - lower) * len(marginal_grid[d])``.
        ``grad`` is indexed as ``(component, row, col)``. Walkers with non-finite coordinates
        or whose cell falls outside ``grad`` are skipped.
        """
        lows = [float(self.env.lower_bound[d]) for d in range(2)]
        spans = [float(self.env.upper_bound[d]) - lows[d] for d in range(2)]
        spacing = [float(self.env.grid_spacing[d]) for d in range(2)]
        sizes = [len(self.env.marginal_grid[d]) for d in range(2)]
        n_rows, n_cols = grad.shape[1], grad.shape[2]
        for point in points:
            pos = [float(point[0]), float(point[1])]
            if not np.all(np.isfinite(pos)):
                continue
            # Grid cell of the walker, computed per coordinate (the -1 offset is kept from the original indexing).
            col, row = (int(np.floor((pos[d] - lows[d]) / spacing[d])) - 1 for d in range(2))
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                continue
            x_idx, y_idx = ((pos[d] - lows[d]) / spans[d] * sizes[d] for d in range(2))
            ax.scatter(x_idx, y_idx, c="r", s=7, marker="o")
            ax.quiver(x_idx, y_idx, -float(grad[0, row, col]), -float(grad[1, row, col]), color="blue", scale=0.2, units="xy")

    @staticmethod
    def _plot_symmetric_contour(ax, field):
        """
        Contour ``field`` with 40 levels symmetric around zero.

        Skipped when the field is empty, identically zero or non-finite, because ``contour``
        requires strictly increasing levels.
        """
        values = Plotter._to_numpy(field)
        if values.size == 0:
            return
        max_abs = float(np.abs(values).max())
        if not np.isfinite(max_abs) or max_abs == 0:
            return
        ax.contour(values, cmap="RdBu_r", levels=np.linspace(-max_abs, max_abs, 40))

    def _compute_flow_diagnostics(self):
        """
        Evaluate forward-policy means, log-flows and layer-wise flow integrals on the 1D grid.

        Both models are switched to eval mode for the evaluation and put back into train
        mode afterwards, even if the evaluation raises.

        Returns:
            ``(forward_flow, flow, layer_wise_integrals)``.
        """
        forward_model = self.gfn.forward_model
        flow_model = self.gfn.logF_model
        forward_model.eval()
        flow_model.eval()
        try:
            with torch.no_grad():
                forward_flow = self.compute_forward_flow(self.gfn.trajectory_length, forward_model, head_idx=0)
                flow = self.compute_flow(self.gfn.trajectory_length, flow_model)
        finally:
            forward_model.train()
            flow_model.train()
        layer_wise_integrals = np.mean(np.exp(flow), axis=1) * self._to_numpy(self.env.grid_spacing[0])
        return forward_flow, flow, layer_wise_integrals

    def _plot_sampler_1d(self, axs, mds, rb, flow_data, iterations, losses, logZs, policy_L1_error):
        """Fill the 1D sampler mosaic: flows, on-policy / replay-buffer samples, loss, Z, noise schedule, policy error."""
        if flow_data is not None:
            forward_flow, flow, layer_wise_integrals = flow_data
            self.plot_forward_flow(axs["forward_flow"], forward_flow)
            self.plot_flow(axs["flow"], flow, layer_wise_integrals, np.exp(logZs[-1]))
        else:
            axs["forward_flow"].axis('off')
            axs["flow"].axis('off')
        self.plot_onpolicy_trajs(axs, self.gfn)
        # NOTE(review): the replay-buffer flag is read from mds.config, so mds must be
        # provided even when metadynamics itself is inactive.
        if mds.config["replay_buffer"]["active"]:
            self.plot_replay_buffer(axs, rb)
        else:
            axs["rb_hist"].axis('off')
            axs["rb_traj"].axis('off')
        self.plot_noise_schedule(axs["noise_schedule"], self.gfn)
        self.plot_evolution(axs["loss"], iterations=iterations, quantity=losses, quantity_name="Loss", log=True)
        self.plot_evolution(axs["Z"], iterations=iterations, quantity=np.exp(logZs), quantity_name="Z", log=False)
        if policy_L1_error is not None:
            self.plot_evolution(axs["op_error"], iterations=iterations, quantity=policy_L1_error, quantity_name="Policy L1 Error", log=False)
        axs['empty'].axis('off')

    @staticmethod
    def _prompt_for_iterations():
        """Ask the user how to continue: returns -1 for ``'q'``, otherwise the integer entered (re-prompts on invalid input)."""
        while True:
            option = input("Enter 'q' to quit or the number of iterations to run: ")
            if option == "q":
                return -1
            try:
                return int(option)
            except ValueError:
                print(f"Invalid input {option!r}: expected 'q' or an integer.")

    def plot_sampler(self, iterations: np.ndarray, mds, rb, out, iter_idx: int, losses, logZs, policy_L1_error=None):
        """
        Master function for plotting the metadynamics algorithm, replay buffer, and on-policy samples.

        Renders ``self.sampler_mosaic`` into ``out`` (cleared first) as a PNG, closes the
        figure, then asks the user how to continue.

        Args:
            iterations: x-values for the loss / Z / policy-error curves.
            mds: metadynamics sampler, drawn when ``config["metad"]["active"]`` and dim < 3.
            rb: replay buffer.
            out: object with ``clear_output()`` usable as a context manager
                (presumably an IPython output widget).
            iter_idx: iteration number printed above the figure.
            losses: loss history.
            logZs: log-partition-function history; ``exp`` of it is plotted as Z in 1D.
            policy_L1_error: optional policy-error history; its panel is skipped when None.

        Returns:
            -1 if the user enters ``'q'``, otherwise the number of iterations entered.
        """
        fig, axs = plt.subplot_mosaic(**self.sampler_mosaic)
        try:
            out.clear_output()

            # Flow plots (1D, and only for losses that train a flow model)
            flow_data = None
            if self.env.dim == 1 and (self.env.config["gfn"]["loss"] in ["STB", "DB"]):
                flow_data = self._compute_flow_diagnostics()

            with out:
                print("Iteration {}".format(iter_idx))
                if self.config["metad"]["active"] and self.env.dim < 3:
                    self.plot_metadynamics(axs, mds)
                if self.env.dim in (2, 3, 4):
                    # Best effort: a failing panel is reported but does not stop the rest of the figure.
                    try:
                        self.plot_trajectories_grid(axs, "rb", rb, "replay_buffer", "Greens", "green")
                    except Exception as e:
                        print("Replay buffer plotting failed.")
                        print(e)
                    try:
                        self.plot_trajectories_grid(axs, "op", self.gfn, "on_policy", "Reds", "red")
                    except Exception as e:
                        print("On-policy plotting failed.")
                        print(e)

                if self.env.dim == 1:
                    self._plot_sampler_1d(axs, mds, rb, flow_data, iterations, losses, logZs, policy_L1_error)
                else:
                    if self.env.dim == 2:
                        self.plot_noise_schedule(axs["noise_schedule"], self.gfn)
                    self.plot_evolution(axs["loss"], iterations=iterations, quantity=losses, quantity_name="Loss", log=True)
                    if policy_L1_error is not None:
                        self.plot_evolution(axs["op_error"], iterations=iterations, quantity=policy_L1_error, quantity_name="Policy L1 error", log=False)

                with BytesIO() as buffer:
                    fig.savefig(buffer, format="png")
                    buffer.seek(0)
                    display(Image(buffer.read(), format='png'))
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

        return self._prompt_for_iterations()